import torch
import math
from torch import nn
from config.model_config import ContextEncoderConfig
from config import GlobalConfig
from util import length_to_mask


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    For position ``p`` and even/odd embedding index ``2k`` / ``2k + 1``::

        PE(p, 2k)     = sin(p / 10000^(2k / d))
        PE(p, 2k + 1) = cos(p / 10000^(2k / d))

    where ``d`` is ``embed_size``. Dropout is applied to the sum of the input
    and the encoding.

    Args:
        embed_size: Feature dimension ``d`` of the inputs (odd values are supported).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Maximum sequence length (number of precomputed positions).
    """

    def __init__(self, embed_size: int, dropout: float = 0.1, max_len: int = 512):
        super().__init__()
        # The ``dropout`` argument used to be accepted but silently ignored.
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, embed_size)  # (max_len, d)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(i/d) for even i, computed in log space:
        # exp(-log(10000^(i/d))) = exp(i * (-log(10000) / d)); i = 0, 2, 4, ...
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )  # (ceil(d / 2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d / 2) odd columns; slicing div_term keeps an odd
        # ``embed_size`` from failing with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d): broadcasts over the batch dimension
        # Registered as a buffer rather than an nn.Parameter: the optimizer does
        # not update it, but it is saved in state_dict and moved by .to()/.cuda().
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x`` of shape (seq_len, batch, embed_size).

        ``seq_len`` must not exceed ``max_len``. Returns a tensor of the same shape.
        """
        x = x + self.pe[:x.size(0), :]  # pe slice: (seq_len, 1, d)
        return self.dropout(x)


class ContextEncoder(nn.Module):
    """Encode a dialog history with a Transformer encoder into one vector per dialog.

    Each turn's feature vector is projected by ``self.linear``. The projected
    sequence is run through ``self.transformer_encoder``, and the encoder
    output at the first position is returned as the dialog representation.
    """

    def __init__(self, config: ContextEncoderConfig):
        super().__init__()
        # NOTE(review): nn.TransformerEncoder deep-copies this layer num_layers
        # times, so the parameters registered under ``transformer_layer`` are
        # never used in forward. The attribute is kept so that existing
        # checkpoints (state_dict keys) still load.
        self.transformer_layer = nn.TransformerEncoderLayer(config.input_size, config.num_head, config.dim_feedforward,
                                                            config.dropout, config.activation).to(GlobalConfig.device)
        self.transformer_encoder = nn.TransformerEncoder(self.transformer_layer, config.num_layers).to(GlobalConfig.device)
        self.linear = nn.Linear(config.input_size, config.embed_size).to(GlobalConfig.device)
        self.position_embedding = PositionalEncoding(config.embed_size, config.dropout, config.max_len).to(GlobalConfig.device)

    def forward(self, context):
        """Encode a batch of dialog contexts.

        Args:
            context: (dialog_context_size, batch_size,
                      text_feat_size + image_feat_size)

        Returns:
            context vector: (batch_size, config.input_size), the Transformer
            output at the first position of each dialog.
        """
        batch_size = context.size(1)
        # nn.Linear acts on the last dimension, so one call projects every turn
        # of every dialog at once (same as the former per-turn loop + stack).
        context = self.linear(context).to(GlobalConfig.device)
        # NOTE(review): the Transformer's d_model is config.input_size while the
        # projection outputs config.embed_size, so this path requires
        # embed_size == input_size — confirm in ContextEncoderConfig.
        # NOTE: self.position_embedding is constructed but not applied here.
        h_n = self.transformer_encoder(context)  # (seq, batch, d_model); batch_first=False
        # Use the output at the first position as the joint representation.
        h_n = h_n[0, :, :]  # (batch_size, d_model)
        # A transpose(0, 1) left over from an RNN h_n of shape
        # (num_layers, batch, hidden) used to sit here. On this 2-D tensor it
        # produced (d_model, batch), and the view below then scrambled features
        # across batch elements. Without it the view is a pure no-op reshape.
        output = h_n.contiguous().view(batch_size, -1)
        return output
